import torch
import matplotlib.pyplot as plt

# Visualize the quadratic error surface e(w1, w2) = w1**2 + 2 * w2**2 as a
# 2-D contour plot (left subplot) and a 3-D surface plot (right subplot).

fig = plt.figure()

# Grid range and resolution shared by both weight axes.
W_MIN, W_MAX, N_POINTS = -1.0, 1.0, 40

# Autograd tracking is not needed just to draw the surface, so the tensors are
# created without requires_grad; this also lets .numpy() be called directly
# instead of requiring .detach() first.
w1_axis = torch.linspace(W_MIN, W_MAX, N_POINTS)
w2_axis = torch.linspace(W_MIN, W_MAX, N_POINTS)

# "ij" indexing: w1 varies along axis 0 and w2 along axis 1. The two grids and
# e all have shape (N_POINTS, N_POINTS), matching the 2-D X/Y/Z arrays that
# contour() and plot_surface() take.
w1, w2 = torch.meshgrid(w1_axis, w2_axis, indexing="ij")
e = w1 ** 2 + 2 * w2 ** 2

# Convert to NumPy once rather than once per plotting call.
w1_np, w2_np, e_np = w1.numpy(), w2.numpy(), e.numpy()

ax1 = fig.add_subplot(121)
ax1.contour(w1_np, w2_np, e_np)
ax1.set_xlabel("w1")
ax1.set_ylabel("w2")
ax1.set_title("contour of e")

ax2 = fig.add_subplot(122, projection="3d")
ax2.plot_surface(w1_np, w2_np, e_np, cmap=plt.cm.YlGnBu_r)
ax2.set_xlabel("w1")
ax2.set_ylabel("w2")
ax2.set_zlabel("e")

plt.show()
